{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Replay & Fork state"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## State History using SDK"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Redefine financial advicer graph with intent checker (from 11_dynamic-breakpoints.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import requests\n",
    "import yfinance as yf\n",
    "from pprint import pformat\n",
    "from langchain_core.tools import Tool\n",
    "from langchain_openai import ChatOpenAI\n",
    "from langchain_core.messages import HumanMessage, SystemMessage\n",
    "from langgraph.graph import MessagesState, START, StateGraph\n",
    "from langgraph.prebuilt import tools_condition, ToolNode\n",
    "from langgraph.checkpoint.memory import MemorySaver\n",
    "from IPython.display import Image, display\n",
    "from langgraph.errors import NodeInterrupt\n",
    "\n",
    "\n",
    "# Defining Tools\n",
    "##################################################################################\n",
    "\n",
    "def lookup_stock_symbol(company_name: str) -> str:\n",
    "    \"\"\"\n",
    "    Converts a company name to its stock symbol using a financial API.\n",
    "\n",
    "    Parameters:\n",
    "        company_name (str): The full company name (e.g., 'Tesla').\n",
    "\n",
    "    Returns:\n",
    "        str: The stock symbol (e.g., 'TSLA') or an error message.\n",
    "    \"\"\"\n",
    "    api_url = \"https://www.alphavantage.co/query\"\n",
    "    params = {\n",
    "        \"function\": \"SYMBOL_SEARCH\",\n",
    "        \"keywords\": company_name,\n",
    "        \"apikey\": \"your_alphavantage_api_key_2\"\n",
    "    }\n",
    "    \n",
    "    response = requests.get(api_url, params=params)\n",
    "    data = response.json()\n",
    "    \n",
    "    if \"bestMatches\" in data and data[\"bestMatches\"]:\n",
    "        return data[\"bestMatches\"][0][\"1. symbol\"]\n",
    "    else:\n",
    "        return f\"Symbol not found for {company_name}.\"\n",
    "\n",
    "\n",
    "def fetch_stock_data_raw(stock_symbol: str) -> dict:\n",
    "    \"\"\"\n",
    "    Fetches comprehensive stock data for a given symbol and returns it as a combined dictionary.\n",
    "\n",
    "    Parameters:\n",
    "        stock_symbol (str): The stock ticker symbol (e.g., 'TSLA').\n",
    "        period (str): The period to analyze (e.g., '1mo', '3mo', '1y').\n",
    "\n",
    "    Returns:\n",
    "        dict: A dictionary combining general stock info and historical market data.\n",
    "    \"\"\"\n",
    "    period = \"1mo\"\n",
    "    try:\n",
    "        stock = yf.Ticker(stock_symbol)\n",
    "\n",
    "        # Retrieve general stock info and historical market data\n",
    "        stock_info = stock.info  # Basic company and stock data\n",
    "        stock_history = stock.history(period=period).to_dict()  # Historical OHLCV data\n",
    "\n",
    "        # Combine both into a single dictionary\n",
    "        combined_data = {\n",
    "            \"stock_symbol\": stock_symbol,\n",
    "            \"info\": stock_info,\n",
    "            \"history\": stock_history\n",
    "        }\n",
    "\n",
    "        return pformat(combined_data)\n",
    "\n",
    "    except Exception as e:\n",
    "        return {\"error\": f\"Error fetching stock data for {stock_symbol}: {str(e)}\"}\n",
    "\n",
    "\n",
    "# Binding tools to the LLM\n",
    "##################################################################################\n",
    "\n",
    "# Create tool bindings with additional attributes\n",
    "lookup_stock = Tool.from_function(\n",
    "    func=lookup_stock_symbol,\n",
    "    name=\"lookup_stock_symbol\",\n",
    "    description=\"Converts a company name to its stock symbol using a financial API.\",\n",
    "    return_direct=False  # Return result to be processed by LLM\n",
    ")\n",
    "\n",
    "fetch_stock = Tool.from_function(\n",
    "    func=fetch_stock_data_raw,\n",
    "    name=\"fetch_stock_data_raw\",\n",
    "    description=\"Fetches comprehensive stock data including general info and historical market data for a given stock symbol.\",\n",
    "    return_direct=False\n",
    ")\n",
    "\n",
    "toolbox = [lookup_stock, fetch_stock]\n",
    "\n",
    "# OPENAI_API_KEY environment variable must be set\n",
    "simple_llm = ChatOpenAI(model=\"gpt-4o-mini\")\n",
    "llm_with_tools = simple_llm.bind_tools(toolbox)\n",
    "\n",
    "\n",
    "# Defining Agent's node\n",
    "##################################################################################\n",
    "\n",
    "# System message\n",
    "assistant_system_message = SystemMessage(content=(\"\"\"\n",
    "You are a professional financial assistant specializing in stock market analysis and investment strategies. \n",
    "Your role is to analyze stock data and provide **clear, decisive recommendations** that users can act on, \n",
    "whether they already hold the stock or are considering investing.\n",
    "\n",
    "You have access to a set of tools that can provide the data you need to analyze stocks effectively. \n",
    "Use these tools to gather relevant information such as stock symbols, current prices, historical trends, \n",
    "and key financial indicators. Your goal is to leverage these resources efficiently to generate accurate, \n",
    "actionable insights for the user.\n",
    "\n",
    "Your responses should be:\n",
    "- **Concise and direct**, summarizing only the most critical insights.\n",
    "- **Actionable**, offering clear guidance on whether to buy, sell, hold, or wait for better opportunities.\n",
    "- **Context-aware**, considering both current holders and potential investors.\n",
    "- **Free of speculation**, relying solely on factual data and trends.\n",
    "- **do not forget** to provide stock name in the report, so it's clear which stock is being recommended.\n",
    "\n",
    "### Response Format:\n",
    "1. **Recommendation:** Buy, Sell, Hold, or Wait.\n",
    "2. **Key Insights:** Highlight critical trends and market factors that influence the decision.\n",
    "3. **Suggested Next Steps:** What the user should do based on their current position.\n",
    "\n",
    "If the user does not specify whether they own the stock, provide recommendations for both potential buyers and current holders. Ensure your advice considers valuation, trends, and market sentiment.\n",
    "\n",
    "Your goal is to help users make informed financial decisions quickly and confidently.\n",
    "\"\"\"))\n",
    "\n",
    "# Node\n",
    "def assistant(state: MessagesState):\n",
    "   return {\"messages\": [llm_with_tools.invoke([assistant_system_message] + state[\"messages\"])]}\n",
    "\n",
    "def intent_check(state: MessagesState):\n",
    "    user_request = state[\"messages\"][-1].content\n",
    "\n",
    "    financial_check_prompt = f\"\"\"\n",
    "    You are an intent classifier. Your task is to determine if the user's request is specifically related to finance, investments, or financial advice.\n",
    "\n",
    "    Evaluate the following user request:\n",
    "\n",
    "    \"{user_request}\"\n",
    "\n",
    "    If the request is about finance, investments, or financial advice, respond with \"True\".\n",
    "    If it is unrelated to finance, respond with \"False\".\n",
    "\n",
    "    Respond with only \"True\" or \"False\" and nothing else.\n",
    "    \"\"\"\n",
    "\n",
    "    llm_response = llm_with_tools.invoke([HumanMessage(content=financial_check_prompt)]).content\n",
    "    is_financial_question = llm_response.strip().lower() == 'true'\n",
    "\n",
    "    if not is_financial_question:\n",
    "        raise NodeInterrupt(\"Please ask a question related to financial advice.\")\n",
    "\n",
    "    return state\n",
    "\n",
    "\n",
    "# Defining Graph\n",
    "##################################################################################\n",
    "builder = StateGraph(MessagesState)\n",
    "\n",
    "# Define nodes: these do the work\n",
    "builder.add_node(\"intent_check\", intent_check)\n",
    "builder.add_node(\"assistant\", assistant)\n",
    "builder.add_node(\"tools\", ToolNode(toolbox))\n",
    "\n",
    "# Define edges: these determine how the control flow moves\n",
    "builder.add_edge(START, \"intent_check\")\n",
    "builder.add_edge(\"intent_check\", \"assistant\")\n",
    "builder.add_conditional_edges(\n",
    "    \"assistant\",\n",
    "    # If the latest message (result) from assistant is a tool call -> tools_condition routes to tools\n",
    "    # If the latest message (result) from assistant is a not a tool call -> tools_condition routes to END\n",
    "    tools_condition,\n",
    ")\n",
    "builder.add_edge(\"tools\", \"assistant\")\n",
    "\n",
    "memory = MemorySaver()\n",
    "graph = builder.compile(checkpointer=memory)\n",
    "display(Image(graph.get_graph().draw_mermaid_png()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Execute the graph and interrupt it"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# start a new conversation\n",
    "thread = {\"configurable\": {\"thread_id\": \"1\"}}\n",
    "\n",
    "# define intiial user request\n",
    "initial_input = {\"messages\": HumanMessage(content=\"What is the weather outside?\")}\n",
    "\n",
    "# run the graph and stream in values mode\n",
    "for event in graph.stream(initial_input, thread, stream_mode=\"values\"):\n",
    "    event['messages'][-1].pretty_print()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Return current checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "state = graph.get_state(thread)\n",
    "print(state.next)\n",
    "print(\"\\n\\n\")\n",
    "print(state)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Get the most recent Chekpoint\n",
    "![get-most-recent-checkpoint](images/get-most-recent-checkpoint.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Let's finish the graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# start a new conversation\n",
    "thread = {\"configurable\": {\"thread_id\": \"2\"}}\n",
    "\n",
    "# define intiial user request\n",
    "initial_input = {\"messages\": HumanMessage(content=\"What is a stock symbol for Tesla?\")}\n",
    "\n",
    "# run the graph and stream in values mode\n",
    "for event in graph.stream(initial_input, thread, stream_mode=\"values\"):\n",
    "    event['messages'][-1].pretty_print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "state = graph.get_state(thread)\n",
    "print(state.next)\n",
    "print(\"\\n\\n\")\n",
    "print(state)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Multiple Checkpoints\n",
    "![multiple-checkpoints](images/multiple-checkpoints.png)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_states = [s for s in graph.get_state_history(thread)]\n",
    "len(all_states)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Replay the Graph\n",
    "\n",
    "Let's replay the graph from assistant."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "to_replay = all_states[-3]\n",
    "to_replay.values"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Check next Node to run"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "to_replay.next"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Check config with `checkpoint_id` and `thread_id`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "to_replay.config"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "and finally let's reply!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for event in graph.stream(None, to_replay.config, stream_mode=\"values\"):\n",
    "    event['messages'][-1].pretty_print()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Important!**\n",
    "\n",
    "Replay does not actually trigger real execution!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Forking with SDK\n",
    "![forking](images/forking.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "to_fork = all_states[-3]\n",
    "to_fork.values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(to_fork.next)\n",
    "print(to_fork.config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Modify specific checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fork_config = graph.update_state(to_fork.config, {\n",
    "    \"messages\": [HumanMessage(content='What is a stock symbol for nvidia?', id=to_fork.values[\"messages\"][0].id)]},\n",
    ")\n",
    "print(fork_config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "State was updated but next node still the same"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "state = graph.get_state(thread)\n",
    "print(state.next)\n",
    "print(\"\\n\\n\")\n",
    "print(state)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Execute the fork (reak work instead of simple replay)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for event in graph.stream(None, fork_config, stream_mode=\"values\"):\n",
    "    event['messages'][-1].pretty_print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "state = graph.get_state(thread)\n",
    "print(state.next)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## State History using API"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langgraph_sdk import get_client\n",
    "\n",
    "URL = \"http://localhost:57698\"\n",
    "client = get_client(url=URL)\n",
    "\n",
    "assistants = await client.assistants.search()\n",
    "assistants"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Replay with API"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Run the graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.messages import convert_to_messages\n",
    "\n",
    "thread = await client.threads.create()\n",
    "input_message = HumanMessage(content=\"What is a stock symbol for Tesla?\")\n",
    "\n",
    "async for event in client.runs.stream(\n",
    "            thread[\"thread_id\"], \n",
    "            assistant_id=\"financial_advisor_intent_check\",\n",
    "            input={\"messages\": [input_message]}, \n",
    "            stream_mode=\"values\",\n",
    "):\n",
    "    messages = event.data.get('messages', None)\n",
    "    if messages:\n",
    "        print(convert_to_messages(messages)[-1])\n",
    "        print(\"\\n\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Get the `checkpoint_id` from state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "states = await client.threads.get_history(thread['thread_id'])\n",
    "print(len(states))\n",
    "to_replay = states[-5]\n",
    "to_replay"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Replay the graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "async for event in client.runs.stream(\n",
    "            thread[\"thread_id\"], \n",
    "            assistant_id=\"financial_advisor_intent_check\",\n",
    "            input=None, \n",
    "            stream_mode=\"values\",\n",
    "            checkpoint_id=to_replay['checkpoint_id']\n",
    "):\n",
    "    messages = event.data.get('messages', None)\n",
    "    if messages:\n",
    "        print(f\"\\n\\n{convert_to_messages(messages)[-1]}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Forking with API"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "to_fork = states[-3]\n",
    "print(to_fork['values'])\n",
    "print(\"\\n\\n\")\n",
    "print(to_fork['next'])\n",
    "print(\"\\n\\n\")\n",
    "print(to_fork['checkpoint_id'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Fork the checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "forked_config = await client.threads.update_state(\n",
    "    thread[\"thread_id\"],\n",
    "    {\"messages\": HumanMessage(\n",
    "        content=\"What is a stock symbol for nvidia?\", \n",
    "        id=to_fork['values']['messages'][0]['id'])},\n",
    "    checkpoint_id=to_fork['checkpoint_id']\n",
    ")\n",
    "print(forked_config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Resume the graph with forked checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "async for event in client.runs.stream(\n",
    "            thread[\"thread_id\"], \n",
    "            assistant_id=\"financial_advisor_intent_check\",\n",
    "            input=None, \n",
    "            stream_mode=\"values\",\n",
    "            checkpoint_id=forked_config['checkpoint_id']\n",
    "):\n",
    "    messages = event.data.get('messages', None)\n",
    "    if messages:\n",
    "        print(f\"\\n\\n{convert_to_messages(messages)[-1]}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
